#  Copyright 2023 Alibaba, Inc. or its affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#       https://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from __future__ import absolute_import

import logging
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Callable, Optional

from pai.api.base import PaginatedResult
from pai.exception import PAIException
from pai.pipeline.artifact import ArchivedArtifact
from pai.session import Session, get_default_session

logger = logging.getLogger(__name__)


# TODO: review the status names of the PipelineRun.
class PipelineRunStatus(object):
    """String constants for the statuses reported for a pipeline run or node."""

    # Not started yet / waiting to be scheduled.
    Initialized = "Initialized"
    ReadyToSchedule = "ReadyToSchedule"

    # In progress.
    Starting = "Starting"
    WorkflowServiceStarting = "WorkflowServiceStarting"
    Running = "Running"

    # Stopped, finished or otherwise not progressing.
    Suspended = "Suspended"
    Terminated = "Terminated"
    Skipped = "Skipped"
    Failed = "Failed"
    Succeeded = "Succeeded"
    Unknown = "Unknown"

    @classmethod
    def completed_status(cls):
        """Return the list of statuses this class reports as completed.

        The list holds ``Suspended``, ``Terminated``, ``Skipped`` and
        ``Failed`` in that order; ``Succeeded`` is not part of it.
        """
        names = ("Suspended", "Terminated", "Skipped", "Failed")
        return [getattr(cls, name) for name in names]

    @classmethod
    def is_running(cls, status):
        """Tell whether ``status`` is one of the in-progress statuses.

        Args:
            status: Status value to test.

        Returns:
            bool: True for ``Starting``, ``Running``,
            ``WorkflowServiceStarting`` and ``ReadyToSchedule``; False for
            anything else.
        """
        in_progress = (
            cls.Starting,
            cls.Running,
            cls.WorkflowServiceStarting,
            cls.ReadyToSchedule,
        )
        return status in in_progress


class PipelineRun(object):
    """Class represent a pipeline run resource."""

    def __init__(
        self,
        run_id,
        name=None,
        workspace_id=None,
        status=None,
        node_id=None,
        duration=None,
        started_at=None,
        finished_at=None,
        source=None,
        user_id=None,
        parent_user_id=None,
        session=None,
    ):
        self.run_id = run_id
        self.name = name
        self.workspace_id = workspace_id
        self.status = status
        self.node_id = node_id
        self.duration = duration
        self.started_at = started_at
        self.finished_at = finished_at
        self.source = source
        self.user_id = user_id
        self.parent_user_id = parent_user_id
        self.session = session or get_default_session()

    @classmethod
    def get(cls, run_id, session=None) -> "PipelineRun":
        session = session or get_default_session()
        return cls.deserialize(session.pipeline_run_api.get(run_id=run_id))

    @classmethod
    def run(
        cls,
        name,
        arguments,
        env=None,
        pipeline_id: Optional[str] = None,
        manifest: Optional[str] = None,
        no_confirm_required: bool = True,
        session: Optional[Session] = None,
    ):
        """Submit a pipeline run with pipeline operator and run arguments.

        If pipeline_id is supplied, remote pipeline manifest is used as workflow template.


        Args:
            name (str): PipelineRun instance name of the submitted job.
            arguments (dict): Run arguments required by pipeline manifest.
            env (list): Environment arguments of run.
            pipeline_id (str): Pipeline
            manifest (str or dict): Pipeline manifest of the run workflow.
            no_confirm_required (bool): Run workflow start immediately if true
                else start_run service call if required to start the workflow.
            session (:class:`pai.session.Session`): A PAI session instance used for
                communicating with PAI service.

        Returns:
            str:run id if run workflow init success.

        """
        session = session or get_default_session()
        run_id = session.pipeline_run_api.create(
            name=name,
            arguments=arguments,
            env=env,
            manifest=manifest,
            pipeline_id=pipeline_id,
            no_confirm_required=no_confirm_required,
        )

        run = PipelineRun.get(run_id)
        logger.info(
            "Create pipeline run succeeded (run_id: {run_id}), please visit the link"
            " below to view the run details.".format(run_id=run_id)
        )
        logger.info(run.console_uri)
        return run_id

    @classmethod
    def list(
        cls,
        name=None,
        run_id=None,
        pipeline_id=None,
        status=None,
        sort_by=None,
        order=None,
        page_size=20,
        page_number=1,
        session=None,
        **kwargs,
    ):
        session = session or get_default_session()
        result = session.pipeline_run_api.list(
            name=name,
            run_id=run_id,
            pipeline_id=pipeline_id,
            status=status,
            sort_by=sort_by,
            order=order,
            workspace_id=None,
            page_size=page_size,
            page_number=page_number,
            **kwargs,
        )

        return [cls.deserialize(run) for run in result.items]

    @classmethod
    def deserialize(cls, d):
        return cls(
            run_id=d["RunId"],
            node_id=d["NodeId"],
            name=d["Name"],
            workspace_id=d["WorkspaceId"],
            user_id=d.get("UserId"),
            parent_user_id=d.get("ParentUserId"),
            source=d.get("Source"),
            started_at=d.get("StartedAt"),
            status=d.get("Status"),
        )

    def __repr__(self):
        return "PipelineRun:%s" % self.run_id

    def travel_node_status_info(self, node_id, max_depth=10):
        node_status_info = dict()

        def pipelines_travel(curr_node_id, parent=None, cur_depth=1):
            if cur_depth > max_depth:
                return
            run_node_detail_info = self.session.pipeline_run_api.get_node(
                self.run_id,
                curr_node_id,
                depth=2,
            )

            if (
                not run_node_detail_info
                or "StartedAt" not in run_node_detail_info["StatusInfo"]
            ):
                return

            if parent is None:
                curr_root_name = self.name
            else:
                curr_root_name = "{0}.{1}".format(
                    run_node_detail_info["Metadata"]["Name"], parent
                )
            node_status_info[curr_root_name] = self._pipeline_node_info(
                run_node_detail_info
            )

            pipelines = run_node_detail_info["Spec"].get("Pipelines", [])
            if not pipelines:
                return
            for sub_pipeline in pipelines:
                node_name = "{0}.{1}".format(
                    curr_root_name, sub_pipeline["Metadata"]["Name"]
                )
                node_status_info[node_name] = self._pipeline_node_info(sub_pipeline)
                next_node_id = sub_pipeline["Metadata"]["NodeId"]
                if sub_pipeline["Metadata"]["NodeType"] == "Dag" and next_node_id:
                    pipelines_travel(next_node_id, curr_root_name, cur_depth + 1)

        pipelines_travel(curr_node_id=node_id)
        return node_status_info

    @staticmethod
    def _pipeline_node_info(pipeline_info):
        return {
            "name": pipeline_info["Metadata"]["Name"],
            "nodeId": pipeline_info["Metadata"]["NodeId"],
            "status": pipeline_info["StatusInfo"]["Status"],
            "startedAt": pipeline_info["StatusInfo"]["StartedAt"],
            "finishedAt": pipeline_info["StatusInfo"].get("FinishedAt", None),
        }

    @property
    def console_uri(self):
        if not self.session.is_inner:
            return "{console_host}?regionId={region_id}#/studio/task/detail/{run_id}".format(
                console_host=self.session.console_uri,
                region_id=self.session.region_id,
                run_id=self.run_id,
            )
        return "{console_host}/#/studio/task/detail/{run_id}".format(
            console_host=self.session.console_uri, run_id=self.run_id
        )

    def get_run_info(self):
        return self.session.pipeline_run_api.get(self.run_id)

    def get_run_node_detail(self, node_id, depth=2):
        return self.session.pipeline_run_api.get_node(
            self.run_id, node_id=node_id, depth=depth
        )

    def get_outputs(self, name=None, node_id=None, depth=1, type=None):
        if not node_id:
            run_info = self.get_run_info()
            node_id = run_info["NodeId"]

        if not node_id:
            return

        result = self.session.pipeline_run_api.list_node_outputs(
            name=name,
            node_id=node_id,
            run_id=self.run_id,
            depth=depth,
            type=type,
        )
        return [ArchivedArtifact.deserialize(output) for output in result.items]

    def get_status(self):
        return self.get_run_info()["Status"]

    def start(self):
        self.session.pipeline_run_api.start(self.run_id)

    def terminate(self):
        self.session.pipeline_run_api.terminate(self.run_id)

    def _wait_for_init(self, retry_interval=1):
        """Wait for "NodeId" allocated to pipeline run."""
        datetime.now()
        run_info = self.get_run_info()
        while (
            PipelineRunStatus.is_running(run_info["Status"]) and not run_info["NodeId"]
        ):
            time.sleep(retry_interval)
            run_info = self.get_run_info()

        if run_info.get("NodeId", None):
            return run_info["NodeId"]
        else:
            raise ValueError("Failed in acquire root node_id of pipeline run.")

    def wait_for_completion(self, show_outputs=True):
        """Wait until the pipeline run stop."""
        run_info = self.get_run_info()
        node_id = run_info["NodeId"]
        if not node_id:
            raise ValueError("Expect NodeId in GetRun response")

        run_status = run_info["Status"]
        if run_status == PipelineRunStatus.Initialized:
            raise ValueError(
                'Pipeline run instance is in status "Init", please start the run instance.'
            )
        elif run_status in (PipelineRunStatus.Terminated, PipelineRunStatus.Suspended):
            raise ValueError(
                "Pipeline run instance is stopped(status:%s), please resume/retry the run."
                % run_status
            )
        elif run_status == PipelineRunStatus.Failed:
            raise ValueError("Pipeline run is failed.")
        elif run_status in (PipelineRunStatus.Skipped, PipelineRunStatus.Unknown):
            raise ValueError(
                "Pipeline run in unexpected status(%s:%s)" % (self.run_id, run_status)
            )
        elif run_status == PipelineRunStatus.Succeeded:
            return

        # Wait for Workflow init.
        print("Wait for run workflow init")

        if show_outputs:
            run_logger = _RunLogger(
                run_instance=self, node_id=node_id, session=self.session
            )
        else:
            run_logger = _MockRunLogger(run_instance=self, node_id=node_id)

        try:
            prev_status_infos = {}
            root_node_status = run_status
            log_runners = []
            while PipelineRunStatus.is_running(root_node_status):
                curr_status_infos = self.travel_node_status_info(node_id)
                for node_fullname, status_info in curr_status_infos.items():
                    if (
                        node_fullname not in prev_status_infos
                        and status_info["status"] != PipelineRunStatus.Skipped
                    ):
                        log_runner = run_logger.submit(
                            node_id=status_info["nodeId"], node_name=node_fullname
                        )
                        if log_runner:
                            log_runners.append(log_runner)
                prev_status_infos = curr_status_infos
                root_node_status = (
                    curr_status_infos[self.name]["status"]
                    if self.name in curr_status_infos
                    else root_node_status
                )

                if root_node_status == PipelineRunStatus.Failed:
                    raise PAIException(
                        "PipelineRun failed: run_id={}, run_status_info={}".format(
                            self.run_id, curr_status_infos
                        )
                    )
                failed_nodes = {
                    name: status_info
                    for name, status_info in curr_status_infos.items()
                    if PipelineRunStatus.Failed == status_info["status"]
                }
                if failed_nodes:
                    raise PAIException(
                        "PipelineRun failed: run_id={}, failed_nodes={}".format(
                            self.run_id, failed_nodes
                        )
                    )

                time.sleep(2)
        except (KeyboardInterrupt, PAIException) as e:
            run_logger.stop_tail()
            raise e

        for log_runner in log_runners:
            _ = log_runner.result()

        return self

    def _wait_with_progress(self):
        pass

    def _wait_with_logger(self, node_id):
        pass


def make_log_iterator(method: Callable, **kwargs):
    """Make an iterator over all items of a paginated list API.

    ``method`` is called repeatedly with ``page_offset`` and ``page_size``
    keyword arguments (plus the remaining ``kwargs``), and the ``items`` of
    each result are yielded in order. Iteration ends at the first page that
    is empty or holds fewer than ``page_size`` items.

    Args:
        method: Resource List API. Must return an object with an ``items``
            sequence.
        **kwargs: arguments for the method. ``page_offset`` (default 0) is
            the offset of the first item to fetch and ``page_size``
            (default 20) the number of items requested per call.

    Returns:
        A resource iterator.
    """

    page_offset = kwargs.get("page_offset", 0)
    page_size = kwargs.get("page_size", 20)

    while True:
        kwargs.update(page_offset=page_offset, page_size=page_size)
        result: PaginatedResult = method(**kwargs)
        items = result.items

        for item in items:
            yield item

        # Only a full page can be followed by more data. Comparing with "<="
        # would stop after the first page, because a page never holds more
        # than ``page_size`` items.
        if len(items) == 0 or len(items) < page_size:
            return
        # Advance by what was actually received so no item is fetched twice.
        page_offset += len(items)


class _RunLogger(object):
    """Prints the node logs of a pipeline run, one background task per node."""

    # Thread pool shared by all instances; every tail runs on one of its workers.
    executor = ThreadPoolExecutor(5)

    def __init__(self, run_instance, node_id, session):
        """Keep the run, node id and session; tailing is enabled initially."""
        super().__init__()
        self.session = session
        self.run_instance = run_instance
        self.node_id = node_id
        self.running_nodes = set()
        self._tail = True

    def tail(
        self,
        node_id,
        node_name,
        page_size=100,
        page_offset=0,
    ):
        """Print the logs of one node until the run stops or tailing is stopped.

        Returns immediately when the node is already being tailed. Each
        printed line is prefixed with ``node_name``.
        """
        if node_id in self.running_nodes:
            return None
        self.running_nodes.add(node_id)

        offset = page_offset
        while self._tail:
            entries = make_log_iterator(
                self.session.pipeline_run_api.list_node_logs,
                run_id=self.run_instance.run_id,
                node_id=node_id,
                page_size=page_size,
                page_offset=offset,
            )

            printed = 0
            for entry in entries:
                print("%s: %s" % (node_name, entry))
                offset += 1
                printed += 1
                # Short pause after every ``page_size`` printed lines.
                if printed % page_size == 0:
                    time.sleep(0.5)

            if printed:
                # Something was printed; ask for more right away.
                continue

            # Nothing new: stop once the run is no longer in a running
            # status, otherwise wait before polling again.
            if not PipelineRunStatus.is_running(self.run_instance.get_status()):
                break
            time.sleep(2)
        return None

    def submit(
        self,
        node_id,
        node_name,
        page_size=100,
        page_offset=0,
    ):
        """Schedule ``tail`` for a node on the shared executor.

        Returns the future of the scheduled task, or None when the node is
        already being tailed.
        """
        print("Add Node Logger: {}, {}".format(node_name, node_id))
        if node_id not in self.running_nodes:
            return self.executor.submit(
                self.tail,
                node_id=node_id,
                node_name=node_name,
                page_size=page_size,
                page_offset=page_offset,
            )
        return None

    def stop_tail(self):
        """Clear the flag checked by ``tail`` so running tails leave their loop."""
        self._tail = False


class _MockRunLogger(object):
    """No-op logger exposing ``tail``, ``submit`` and ``stop_tail``."""

    def __init__(self, run_instance, node_id):
        """Keep the run instance and node id; nothing else is set up."""
        super().__init__()
        self.node_id = node_id
        self.run_instance = run_instance

    def tail(self, **kwargs):
        """Do nothing."""
        return None

    def submit(self, *args, **kwargs):
        """Do nothing and return None, so no log task is ever scheduled."""
        return None

    def stop_tail(self):
        """Do nothing."""
        return None
